import os
from torchvision.datasets import MNIST
import torch

# 设置下载路径为当前目录下的mnist_data文件夹
download_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mnist_data")
os.makedirs(download_path, exist_ok=True)

print(f"正在下载MNIST数据集到 {download_path}...")

# 下载训练集和测试集
train_dataset = MNIST(download_path, train=True, download=True)
test_dataset = MNIST(download_path, train=False, download=True)

print("MNIST数据集下载完成。")
print(f"文件位置: {download_path}")
print("文件列表:")
for root, dirs, files in os.walk(download_path):
    for file in files:
        print(os.path.join(root, file))
